The MNIST dataset is a popular dataset for data science and machine learning tasks. It consists of 28 × 28 grayscale images of handwritten digits (0 through 9).

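We’ll be working with a subset of this dataset: the first 5,000 rows of data/mnist.csv, together with a precomputed UMAP embedding of the images loaded from data/mnist_umap.RDS.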

library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.2     ✓ purrr   0.3.4
## ✓ tibble  3.0.4     ✓ dplyr   1.0.1
## ✓ tidyr   1.1.0     ✓ stringr 1.4.0
## ✓ readr   1.3.1     ✓ forcats 0.5.0
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(plotly)
# every column (label and pixels) is numeric; stating this explicitly silences readr's parsing message
mnist <- read_csv('data/mnist.csv', n_max = 5000, col_types = cols(.default = col_double()))
mnist_umap <- readRDS('data/mnist_umap.RDS')
# name the embedding's columns V1, V2, ... so that as_tibble() does not warn about missing column names
colnames(mnist_umap) <- paste0("V", seq_len(ncol(mnist_umap)))

mnist_umap %>% 
  as_tibble() %>% 
  ggplot(aes(x = V1, y = V2)) +
  geom_point(size = 0.25)

First, let’s cluster the UMAP embedding of the MNIST data, starting with k-means.
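k-means alternates between two steps until the assignments stop changing. First, it assigns every point to its nearest centroid. Then it moves each centroid to the mean of the points assigned to it. The sketch below shows this loop (Lloyd's algorithm) in base R. The helper `lloyd_kmeans()` and the toy data are hypothetical. The sketch also skips details that `stats::kmeans()` handles, such as multiple random starts; by default `kmeans()` uses the Hartigan-Wong algorithm rather than plain Lloyd iterations.

```r
# Minimal sketch of Lloyd's k-means algorithm (illustrative only; hypothetical helper)
lloyd_kmeans <- function(x, k, max_iter = 100) {
  x <- as.matrix(x)
  centers <- x[sample(nrow(x), k), , drop = FALSE]   # random initial centroids drawn from the data
  assignment <- integer(nrow(x))
  for (iter in seq_len(max_iter)) {
    # Assignment step: squared Euclidean distance from every point to every centroid
    d2 <- sapply(seq_len(k), function(j) colSums((t(x) - centers[j, ])^2))
    new_assignment <- max.col(-d2, ties.method = "first")
    if (all(new_assignment == assignment)) break      # converged: no point changed cluster
    assignment <- new_assignment
    # Update step: move each centroid to the mean of its assigned points
    for (j in seq_len(k)) {
      if (any(assignment == j)) {
        centers[j, ] <- colMeans(x[assignment == j, , drop = FALSE])
      }
    }
  }
  list(cluster = assignment, centers = centers)
}

# Toy data: two well-separated blobs of 50 points each
set.seed(1)
toy <- rbind(matrix(rnorm(100, mean = 0), ncol = 2),
             matrix(rnorm(100, mean = 10), ncol = 2))
fit <- lloyd_kmeans(toy, k = 2)
table(fit$cluster)
```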

set.seed(123)

max_clusters <- 20

ss <- numeric(max_clusters)  # total within-cluster sum of squares for each number of clusters
for (i in 1:max_clusters){
  clusters <- mnist_umap %>%
    kmeans(centers = i) 
  ss[i] <- clusters$tot.withinss
}

wss <- tibble(clusters = 1:max_clusters, wss = ss)

wss %>% ggplot() +
  geom_point(aes(x = clusters, y = wss)) +
  geom_line(aes(x = clusters, y = wss)) +
  scale_x_continuous(breaks = 1:max_clusters) +
  labs(title = "K-Means Clustering Results", x = "Number of Clusters", y = "Within-Cluster Total Sum of Squares")

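Look for the “elbow” in this plot: the point after which adding more clusters reduces the within-cluster sum of squares only slightly. It suggests that 4 might be a good number of clusters to use.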
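The quantity on the y-axis is `tot.withinss`. It is the sum, over all clusters, of the squared Euclidean distances between each point and the centroid of its cluster. As a sanity check, it can be recomputed by hand from a `kmeans()` fit. The helper `within_ss()` and the random toy data below are hypothetical.

```r
# Recompute k-means' total within-cluster sum of squares by hand (hypothetical helper)
within_ss <- function(x, cluster) {
  x <- as.matrix(x)
  sum(sapply(unique(cluster), function(j) {
    members <- x[cluster == j, , drop = FALSE]
    centroid <- colMeans(members)
    sum(sweep(members, 2, centroid)^2)   # squared distances of the members to their centroid
  }))
}

set.seed(123)
toy <- matrix(rnorm(200), ncol = 2)
fit <- kmeans(toy, centers = 3)
c(by_hand = within_ss(toy, fit$cluster), kmeans = fit$tot.withinss)
```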

mnist_kmeans <- mnist_umap %>% 
  kmeans(centers = 4)

mnist_umap %>% 
  as_tibble() %>% 
  mutate(cluster = as.factor(mnist_kmeans$cluster)) %>% 
  ggplot(aes(x = V1, y = V2, color = cluster)) +
  geom_point(size = 0.25) +
  guides(colour = guide_legend(override.aes = list(size=2)))

However, the scatter plot suggests that there are at least 6 distinct clusters, so let’s try 6 centers instead.

set.seed(123)

mnist_kmeans <- mnist_umap %>% 
  kmeans(centers = 6)

mnist_umap %>% 
  as_tibble() %>% 
  mutate(cluster = as.factor(mnist_kmeans$cluster)) %>% 
  ggplot(aes(x = V1, y = V2, color = cluster)) +
  geom_point(size = 0.25) +
  guides(colour = guide_legend(override.aes = list(size=2)))

Even so, k-means does not do a good job of identifying what look like the natural clusters in our dataset. This is because k-means assigns each point to its nearest centroid and minimizes the within-cluster sum of squares, so it favors compact, roughly spherical clusters of similar size.
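A small synthetic example makes this limitation concrete. To the eye, two concentric rings are obvious clusters. However, k-means assigns each point to its nearest centroid, so the boundary between two k-means clusters is a straight line, and it cannot separate one ring from the other. The ring data below are made up for illustration.

```r
# Two concentric rings: 100 evenly spaced points at radius 1 and 100 at radius 5
theta <- seq(0, 2 * pi, length.out = 101)[-101]
rings <- rbind(cbind(cos(theta), sin(theta)),
               cbind(5 * cos(theta), 5 * sin(theta)))
ring_id <- rep(c("inner", "outer"), each = 100)

# k-means with 2 centers; nstart = 10 keeps the best of 10 random starts
set.seed(123)
km <- kmeans(rings, centers = 2, nstart = 10)

# Cross-tabulate the k-means clusters against the true rings
table(kmeans = km$cluster, ring = ring_id)
```

Typically each k-means cluster ends up holding part of the inner ring and part of the outer ring, because the best straight-line split cuts through both.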

Let’s look at an alternative approach: density-based clustering. Rather than looking for spherical clusters, density-based clustering looks for high-density regions separated by lower-density regions.
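Here is a minimal base-R sketch of the original DBSCAN algorithm, the non-hierarchical predecessor of HDBSCAN, to show what “high-density regions separated by lower-density regions” means in practice. A point is a core point if at least `min_pts` points, itself included, lie within distance `eps`. Clusters grow outward from core points, and points that no core point can reach are labelled noise (0). The helper `dbscan_sketch()`, the ring data, and the parameter values are hypothetical. In practice, the dbscan package's `dbscan()` function does this far more efficiently.

```r
# Minimal DBSCAN sketch (illustrative; hypothetical helper, O(n^2) memory)
dbscan_sketch <- function(x, eps, min_pts) {
  d <- as.matrix(dist(x))                                  # pairwise Euclidean distances
  n <- nrow(d)
  neighbors <- lapply(seq_len(n), function(i) which(d[i, ] <= eps))  # includes the point itself
  is_core <- lengths(neighbors) >= min_pts
  cluster <- integer(n)                                    # 0 = noise / not yet assigned
  current <- 0L
  for (i in which(is_core)) {
    if (cluster[i] != 0L) next                             # already part of a cluster
    current <- current + 1L
    cluster[i] <- current
    queue <- i
    while (length(queue) > 0) {
      p <- queue[1]
      queue <- queue[-1]
      if (!is_core[p]) next                                # border points join but do not expand
      for (q in neighbors[[p]]) {
        if (cluster[q] == 0L) {
          cluster[q] <- current
          queue <- c(queue, q)
        }
      }
    }
  }
  cluster
}

# Two concentric rings: 100 evenly spaced points at radius 1 and 100 at radius 5
theta <- seq(0, 2 * pi, length.out = 101)[-101]
rings <- rbind(cbind(cos(theta), sin(theta)),
               cbind(5 * cos(theta), 5 * sin(theta)))
ring_id <- rep(c("inner", "outer"), each = 100)

db <- dbscan_sketch(rings, eps = 1, min_pts = 5)
table(dbscan = db, ring = ring_id)
```

Unlike k-means, this recovers each ring as its own cluster, because the rings are 4 units apart and `eps` is only 1.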

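We’ll specifically be using HDBSCAN (hierarchical DBSCAN), which builds a hierarchy of density-based clusters and keeps the most stable ones. We do not have to specify the number of clusters ahead of time; the algorithm decides it. Instead, we set minPts, which the dbscan package documents as the minimum cluster size.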
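The following is a rough sketch of where HDBSCAN's hierarchy comes from:

- Each point gets a core distance: the distance to its `minPts`-th nearest neighbor, counting the point itself (one common convention).
- Pairs of points are compared using the mutual reachability distance, max(core(a), core(b), d(a, b)).
- Single-linkage clustering on that distance gives the cluster hierarchy.

HDBSCAN then condenses this tree and extracts the most stable clusters, a step the sketch does not attempt. The helper `mutual_reachability()` and the ring data are hypothetical.

```r
# Mutual reachability distance, the metric HDBSCAN builds its hierarchy on (hypothetical helper)
mutual_reachability <- function(x, min_pts) {
  d <- as.matrix(dist(x))
  core <- apply(d, 1, function(row) sort(row)[min_pts])  # sorted row starts with the self-distance 0
  mr <- pmax(d, outer(core, core, pmax))                 # max(core(a), core(b), d(a, b))
  diag(mr) <- 0
  mr
}

# Two concentric rings: 100 evenly spaced points at radius 1 and 100 at radius 5
theta <- seq(0, 2 * pi, length.out = 101)[-101]
rings <- rbind(cbind(cos(theta), sin(theta)),
               cbind(5 * cos(theta), 5 * sin(theta)))
ring_id <- rep(c("inner", "outer"), each = 100)

mr <- mutual_reachability(rings, min_pts = 5)
hc <- hclust(as.dist(mr), method = "single")   # single-linkage tree on mutual reachability
table(cut = cutree(hc, k = 2), ring = ring_id)
```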

library(dbscan)

set.seed(123)
mnist_dbscan <- mnist_umap %>% 
  hdbscan(minPts = 30)

mnist_umap %>% 
  as_tibble() %>% 
  mutate(cluster = as.factor(mnist_dbscan$cluster)) %>% 
  ggplot(aes(x = V1, y = V2, color = cluster)) +
  geom_point(size = 0.25) +
  guides(colour = guide_legend(override.aes = list(size=2)))

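We end up with what looks to be a more sensible clustering. Points labelled as cluster 0 are those HDBSCAN treats as noise rather than assigning to any cluster.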

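Another advantage of HDBSCAN is that it can flag potential outliers. K-means must assign every point to a cluster, whereas HDBSCAN can leave points out as noise. The hdbscan() result also includes a GLOSH (Global-Local Outlier Score from Hierarchies) score for every observation. This score compares the density around each observation with the density of the densest region of the cluster it is attached to, and higher scores indicate more outlying points.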
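For reference, GLOSH is defined on the HDBSCAN hierarchy roughly as follows. Here λ(x_i) is the density level (the inverse of the distance threshold) at which x_i stops belonging to its cluster. λ_max(x_i) is the highest density level at which that cluster, or any of its sub-clusters, still exists. A point about as dense as the densest part of its cluster scores close to 0, while a point in a sparse region next to a dense cluster scores close to 1.

```latex
\mathrm{GLOSH}(x_i) = \frac{\lambda_{\max}(x_i) - \lambda(x_i)}{\lambda_{\max}(x_i)} = 1 - \frac{\lambda(x_i)}{\lambda_{\max}(x_i)}
```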

p <- mnist_umap %>% 
  as_tibble() %>% 
  mutate(outlier_score = mnist_dbscan$outlier_scores) %>% 
  ggplot(aes(x = V1, y = V2, color = outlier_score)) +
  geom_point(size = 0.5) +
  scale_color_viridis_c()

ggplotly(p)

We can highlight the 10 observations with the highest outlier scores.

# indices of the 10 observations with the highest GLOSH outlier scores
potential_outliers <- order(mnist_dbscan$outlier_scores, decreasing = TRUE)[1:10]

p <- mnist_umap %>% 
  as_tibble() %>% 
  ggplot(aes(x = V1, y = V2)) +
  geom_point(size = 0.25) + 
  # overlay the potential outliers as larger orange points
  geom_point(data = mnist_umap[potential_outliers, ] %>% as_tibble(), fill = 'orange', size = 2, shape = 21)

ggplotly(p)

Finally, we can see what these potential outliers look like, starting with the observation that has the highest outlier score.

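# pixel values of the top-scoring outlier; with byrow = FALSE, im[i, j] holds
# pixel column i of image row j (assuming the CSV stores each image row by row)
im <- matrix(mnist[potential_outliers[1],] %>% select(-c(label)) %>% unlist(), nrow = 28, byrow = FALSE)

# image() draws the y axis upwards, so flip vertically to show the digit the right way up
im <- im[, ncol(im):1]

par(pty = "s")
image(1:28, 1:28, im, col = gray(rev((0:255)/255)))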
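The reshaping above depends on how the pixels are laid out in the CSV. The sketch below assumes the common layout in which the 784 pixel columns store each 28 × 28 image row by row, so pixel k (counting from 0) sits in row k %/% 28 and column k %% 28. It checks that layout on a synthetic vector. It then shows the equivalent transpose-and-flip step for `image()`, which places z[i, j] at (x[i], y[j]) with y increasing upwards. The synthetic pixel data are hypothetical.

```r
# Synthetic 784-pixel vector with a single bright pixel at image row 2, column 5 (0-based)
pixels <- numeric(784)
pixels[2 * 28 + 5 + 1] <- 255          # +1 because R vectors are 1-indexed

# Row-major layout: fill the 28 x 28 matrix row by row
img <- matrix(pixels, nrow = 28, byrow = TRUE)
which(img == 255, arr.ind = TRUE)      # row 3, column 6 in R's 1-based indexing

# image() puts z[i, j] at (x[i], y[j]) with y pointing up, so transpose and flip vertically
z <- t(img)[, 28:1]
par(pty = "s")
image(1:28, 1:28, z, col = gray(rev((0:255) / 255)), axes = FALSE)
```

Note that t(img) is exactly what matrix(..., byrow = FALSE) produces, and [, 28:1] is the same vertical flip as reversing each row of im.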

There are other metrics that are useful for detecting outliers, including the k-nearest-neighbor distance, the local outlier factor (LOF), and isolation scores based on isolation forests.
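Here is a rough sketch of the first two alternatives:

- The k-nearest-neighbor distance scores each point by the distance to its k-th nearest neighbor.
- The local outlier factor compares a point's local density with the local densities of its neighbors, so values well above 1 suggest an outlier.

The base-R helper `outlier_scores_sketch()`, the toy grid, and k = 3 are hypothetical. The dbscan package also provides `kNNdist()` and `lof()` for real use. Definitions differ slightly between implementations, for example in whether a point counts as its own neighbor.

```r
# kNN-distance and LOF outlier scores in base R (illustrative sketch; hypothetical helper)
outlier_scores_sketch <- function(x, k) {
  d <- as.matrix(dist(x))
  diag(d) <- Inf                                         # a point is not its own neighbor here
  n <- nrow(d)
  k_dist <- apply(d, 1, function(row) sort(row)[k])     # distance to the k-th nearest neighbor
  nbrs <- lapply(seq_len(n), function(i) which(d[i, ] <= k_dist[i]))  # k-neighborhood (ties kept)
  # Local reachability density: inverse of the mean reachability distance to the neighbors
  lrd <- sapply(seq_len(n), function(i) {
    reach <- pmax(k_dist[nbrs[[i]]], d[i, nbrs[[i]]])
    1 / mean(reach)
  })
  # LOF: average ratio of the neighbors' densities to the point's own density
  lof <- sapply(seq_len(n), function(i) mean(lrd[nbrs[[i]]]) / lrd[i])
  data.frame(knn_dist = unname(k_dist), lof = lof)
}

# Toy data: a 5 x 5 grid plus one distant point (row 26)
pts <- rbind(as.matrix(expand.grid(x = 0:4, y = 0:4)), c(10, 10))
scores <- outlier_scores_sketch(pts, k = 3)
tail(scores, 3)
```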